"""
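graph_utils.weights
-------------------

Efficient computation of directed edge weights for kNN graphs, with support
for various kernel rules and sigma symmetrization schemes.

Includes:
- ``compute_weights_vectorized``: the main weight computation, Numba-accelerated
  (parallel over samples) when Numba is importable, with a pure-Python
  fallback otherwise.
- ``sigma_eff``: helper that selects the effective sigma for an edge, for use
  in custom kernel rules.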
"""

import math
from collections.abc import Callable

import numpy as np
# The weight computation can be accelerated using Numba.  When Numba is
# unavailable at import time we fall back to pure Python implementations.
try:
    import numba  # type: ignore[import]
    _HAS_NUMBA = True
except Exception:
    numba = None  # type: ignore[assignment]
    _HAS_NUMBA = False

def sigma_eff(sym_code: int, si: float, sj: float) -> float:
    """
    Pick the bandwidth used for the directed edge i -> j.

    Parameters
    ----------
    sym_code : int
        Sigma fusion rule: 0=mean, 1=max, 2=umap, 3=geom, 4=min, 5=harm.
        Only ``1`` (larger of the two sigmas) and ``3`` (geometric mean)
        combine both endpoints; every other code keeps the source sigma.
    si : float
        Bandwidth of the source sample i.
    sj : float
        Bandwidth of the neighbour j.

    Returns
    -------
    float
        The effective sigma for the edge.
    """
    if sym_code == 3:
        # geom: geometric mean of the two bandwidths
        return (si * sj) ** 0.5
    if sym_code != 1:
        # mean / umap / min / harm (and any other code): source sigma only
        return si
    # max: the larger bandwidth, with si winning ties
    return si if si >= sj else sj


if _HAS_NUMBA:
    # Compile the very same function for use inside Numba kernels;
    # inline='always' asks Numba to splice it into its callers.
    sigma_eff = numba.njit(inline='always')(sigma_eff)  # type: ignore[misc]

if _HAS_NUMBA:
    @numba.njit(parallel=True, fastmath=True, cache=True)  # type: ignore[misc]
    def compute_weights_vectorized(
        neigh_idx: np.ndarray,
        neigh_dist: np.ndarray,
        rho: np.ndarray,
        sigmas: np.ndarray,
        sym_code: int,
        n_samples: int,
        k: int,
        kernel_function,            # numba-compiled kernel(r, params) -> value
        kernel_params: np.ndarray,  # float32[:]
    ):
        """
        Compute directed weights P_ij for a kNN graph (Numba backend).

        Rows are processed in parallel with ``numba.prange``; each row writes
        only its own contiguous slice of the output arrays.

        Parameters
        ----------
        neigh_idx : np.ndarray of shape (n_samples, k-1)
            Indices of nearest neighbours (only the first ``k - 1`` columns
            are read).
        neigh_dist : np.ndarray of shape (n_samples, k-1)
            Distances to nearest neighbours, aligned with ``neigh_idx``.
        rho : np.ndarray of shape (n_samples,)
            Local connectivity offsets; ``max(d_ij - rho_i, 0)`` is what gets
            scaled and fed to the kernel.
        sigmas : np.ndarray of shape (n_samples,)
            Per-sample kernel bandwidths. Assumed strictly positive: a zero
            effective sigma divides by zero.
        sym_code : int
            Sigma fusion rule, resolved by ``sigma_eff`` (0=mean, 1=max,
            2=umap, 3=geom, 4=min, 5=harm, 6=sinkhorn). Only 1 and 3 combine
            both endpoints; all other codes use sigma_i.
        n_samples : int
            Number of samples (rows) to process.
        k : int
            Number of neighbours per sample (including self).
        kernel_function : numba-compiled callable
            ``kernel_function(r, kernel_params) -> float`` evaluated at the
            scaled distance ``r = max(d_ij - rho_i, 0) / sigma_eff``.
        kernel_params : np.ndarray
            Extra parameters forwarded unchanged to ``kernel_function``
            (presumably a float32 array, per the signature comment).

        Returns
        -------
        rows, cols, vals : np.ndarray
            Edge sources (int32), targets (int32) and weights (float32), each
            of length ``n_samples * (k - 1)``; edge ``(i, j)`` is stored at
            position ``i * (k - 1) + j``.
        """
        n_nbrs = k - 1
        total_edges = n_samples * n_nbrs
        rows = np.empty(total_edges, dtype=np.int32)
        cols = np.empty(total_edges, dtype=np.int32)
        vals = np.empty(total_edges, dtype=np.float32)

        # Each row i owns the disjoint slice [i * n_nbrs, (i + 1) * n_nbrs),
        # so parallel iterations never write to the same position.
        for i in numba.prange(n_samples):
            rho_i = rho[i]
            sigma_i = sigmas[i]
            start = i * n_nbrs

            for j in range(n_nbrs):
                nbr = neigh_idx[i, j]

                # Distance beyond the local connectivity offset, clamped at 0.
                diff = neigh_dist[i, j] - rho_i
                if diff < 0.0:
                    diff = 0.0

                # Single source of truth for the fusion rule (inlined by Numba).
                s_eff = sigma_eff(sym_code, sigma_i, sigmas[nbr])

                idx = start + j
                rows[idx] = i
                cols[idx] = nbr
                vals[idx] = kernel_function(diff / s_eff, kernel_params)

        return rows, cols, vals
else:
    def compute_weights_vectorized(
        neigh_idx: np.ndarray,
        neigh_dist: np.ndarray,
        rho: np.ndarray,
        sigmas: np.ndarray,
        sym_code: int,
        n_samples: int,
        k: int,
        kernel_function: Callable[..., float],
        kernel_params: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute directed weights P_ij for a kNN graph (pure-Python backend).

        This implementation does not require Numba. It is slower, but it
        lets the package work when Numba is not installed.

        Parameters
        ----------
        neigh_idx : np.ndarray of shape (n_samples, k-1)
            Indices of nearest neighbours (only the first ``k - 1`` columns
            are read).
        neigh_dist : np.ndarray of shape (n_samples, k-1)
            Distances to nearest neighbours, aligned with ``neigh_idx``.
        rho : np.ndarray of shape (n_samples,)
            Local connectivity offsets; ``max(d_ij - rho_i, 0)`` is what gets
            scaled and fed to the kernel.
        sigmas : np.ndarray of shape (n_samples,)
            Per-sample kernel bandwidths. Assumed strictly positive: a zero
            effective sigma divides by zero.
        sym_code : int
            Sigma fusion rule, resolved by ``sigma_eff`` (0=mean, 1=max,
            2=umap, 3=geom, 4=min, 5=harm, 6=sinkhorn). Only 1 and 3 combine
            both endpoints; all other codes use sigma_i.
        n_samples : int
            Number of samples (rows) to process.
        k : int
            Number of neighbours per sample (including self).
        kernel_function : callable
            ``kernel_function(r, kernel_params) -> float`` evaluated at the
            scaled distance ``r = max(d_ij - rho_i, 0) / sigma_eff``.
        kernel_params : np.ndarray
            Extra parameters forwarded unchanged to ``kernel_function``.

        Returns
        -------
        rows, cols, vals : np.ndarray
            Edge sources (int32), targets (int32) and weights (float32), each
            of length ``n_samples * (k - 1)``; edge ``(i, j)`` is stored at
            position ``i * (k - 1) + j``.
        """
        n_nbrs = k - 1
        total_edges = n_samples * n_nbrs
        rows = np.empty(total_edges, dtype=np.int32)
        cols = np.empty(total_edges, dtype=np.int32)
        vals = np.empty(total_edges, dtype=np.float32)

        idx = 0
        for i in range(n_samples):
            rho_i = rho[i]
            sigma_i = sigmas[i]
            for j in range(n_nbrs):
                nbr = int(neigh_idx[i, j])

                # Distance beyond the local connectivity offset, clamped at 0.
                diff = neigh_dist[i, j] - rho_i
                if diff < 0.0:
                    diff = 0.0

                # Same helper as the Numba backend, so both apply one rule.
                s_eff = sigma_eff(sym_code, sigma_i, sigmas[nbr])

                rows[idx] = i
                cols[idx] = nbr
                vals[idx] = kernel_function(diff / s_eff, kernel_params)
                idx += 1

        return rows, cols, vals

# Public API exported by ``from graph_utils.weights import *``.
__all__ = [
    "sigma_eff",
    "compute_weights_vectorized",
]